// Copyright 2019 eBay Inc.
// Primary authors: Simon Fell, Diego Ongaro,
//                  Raymond Kroeker, and Sathish Kandasamy.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package search

import (
	"fmt"
	"io"
	"strconv"
	"strings"

	"github.com/ebay/akutan/util/bytes"
	"github.com/ebay/akutan/util/cmp"
	log "github.com/sirupsen/logrus"
)

// A Space represents the search space of a single query's equivalent logical
// expressions and physical implementation plans. In other words, it holds the
// explored plans for a single query.
//
// Space provides a somewhat low-level interface to the query planner; for a
// simpler interface, see Prepare.
//
// The methods on Space are also likely to change in the future. Right now,
// exploration, implementation, and costing are done as discrete steps. The
// Cascades/Columbia papers argue for running some implementation rules and
// predicting some costs earlier in the process.
//
// The core methods of Space must be invoked in this exact order:
//	  space := NewSpace(expr, rules)
//	  space.Explore()
//	  space.Implement()
//	  space.PredictCosts(estimator)
//	  return space.BestPlan()
// In between these calls, however, you can invoke String(), Debug(), or
// Graphviz() to get more information.
type Space struct {
	// Implementation notes
	//
	// The Space's possible expression trees are represented as a DAG of groups
	// (Group), where each group is a set of logically equivalent expressions
	// (Expr). The Root group sits at the top level and contains expressions that
	// are logically equivalent to the input query. Each expression takes zero or
	// more groups as input; this is a compact way to represent that the parent Expr
	// can take as input any Expr from each of those groups. Note that each group
	// may serve as input to multiple expressions in multiple higher-level groups.
	//
	// A Space starts with a single expression tree (the input to the query planner),
	// where each group consists of a single expression. Then, repeatedly, the query
	// planner finds an expression in the DAG and proposes an equivalent expression
	// for the same group. This equivalent expression gets integrated into the Space:
	// either its inputs already exist in lower-level groups, or new groups are
	// created. The new expression itself goes into the same group as the original,
	// since the query planner is asserting that they are logically equivalent.
	//
	// Depending on the search algorithm and rule set, it's possible or likely for
	// the same expression to be derived multiple ways. However, each expression
	// will only appear once in the DAG. This is a key feature of this data
	// structure to avoid duplicating work. An expression's identity is defined by
	// an operator and a list of input group IDs. A hash table allows looking up an
	// expression in the DAG from this identity; it's used to avoid inserting
	// duplicate expressions into the DAG.
	//
	// If an expression is derived as logically equivalent to one group in the DAG,
	// but that expression is found to already exist in another group, this implies
	// that the two groups are logically equivalent. They are immediately merged
	// together and the duplicate expression is discarded. This can also affect the
	// parent groups that used the merged group as a direct input: expressions
	// previously believed to be distinct may also merge.

	// The top-level group whose expressions are logically equivalent to the input
	// query. Unless otherwise stated, this field and everything accessible through
	// it are read-only outside this package.
	root *Group
	// Every expression is present in this hash table. The map key is the
	// expression's identity as produced by cmp.GetKey(expr), which uses
	// Expr.Key (operator key plus input group IDs).
	exprHash map[string]*Expr
	// A counter to be assigned to the next Group to be created, then incremented.
	// groups() also uses it as a size hint for how many groups are reachable.
	nextGroupID int
	// A counter to be assigned to the next Expr to be created, then incremented.
	nextExprID int

	// Configuration for this Space. NOTE(review): not read in this part of the
	// file; presumably supplied when the Space is created — confirm.
	def     Definition
	options Options
}

// A Group is a set of logically equivalent expressions.
type Group struct {
	// A unique identifier for the Group within the Space, which is useful for
	// human-readable output. Expressions also refer to their input Groups by
	// this ID when building their identity (see Expr.Key).
	ID int
	// The set of logically equivalent expressions that make up this Group, in no
	// particular order. Groups reachable from the root of the DAG will have
	// a non-empty set of expressions. Unless otherwise stated, this field and
	// everything accessible through it are read-only outside this package.
	Exprs []*Expr
	// The query planner may write to this field to indicate which of the
	// expressions has the lowest cost. It may be nil (e.g. before costing).
	Best *Expr
	// User-defined properties that are true for every Expr in this Group.
	// May be nil; the debug printers check for that before using it.
	LogicalProp LogicalProperties
	// If non-nil, this Group has been merged into another group. This implies
	// this Group will no longer be reachable from the root of the Space, but the
	// caller may have kept a Group pointer across a call that merged Groups.
	mergedInto *Group
}

// An Expr bundles a logical or physical operator along with its input groups.
// Its identity within the Space is its operator plus the IDs of its input
// groups (see Key).
type Expr struct {
	// Logical or physical operator.
	Operator Operator
	// The operator reads from these to generate its output. The order matters:
	// it is part of the expression's identity.
	Inputs []*Group
	// The cost of this algorithm, excluding the costs of its inputs. Nil until
	// the expression has been costed.
	LocalCost Cost
	// A lower bound on the cost of this expression (including the costs of its
	// inputs).
	CombinedCost Cost
	// The expression's equivalence class: the Group whose Exprs contains this
	// expression.
	Group *Group
	// A logical identifier for this expression. Every Expr instance in the memo
	// should have a unique id (CheckInvariants verifies this for reachable
	// expressions). Stale expressions may have an id that is also in use in the
	// memo. Used to avoid exploring the same expression twice, even if it's been
	// moved to a different group.
	id int
	// If empty, this Expr is valid and part of the memo structure. If not empty,
	// it contains the reason why this Expr should be considered stale. If stale,
	// then this instance is no longer in the memo structure.
	staleReason string
}

// isStale reports whether this Expr instance has been dropped from the memo,
// which is signalled by a non-empty staleReason. Code that holds on to an Expr
// while mutating the memo should call isStale before using that Expr again.
func (expr *Expr) isStale() bool {
	return len(expr.staleReason) > 0
}

// groups returns every distinct Group reachable from the root, ordered
// bottom-up: the list is topologically sorted so that a Group appears before
// any Group having an Expr that takes it as input. The root is always last.
func (space *Space) groups() []*Group {
	// Exploration calls this often, so pre-size using the number of groups
	// created so far (capped at 256) as an estimate of how many are reachable.
	sizeHint := cmp.MinInt(256, space.nextGroupID)
	visited := make(map[int]struct{}, sizeHint)
	ordered := make([]*Group, 0, sizeHint)

	// visit is a post-order depth-first walk: a group is appended only after
	// the inputs of all of its expressions have been appended.
	var visit func(g *Group)
	visit = func(g *Group) {
		if _, done := visited[g.ID]; done {
			return
		}
		visited[g.ID] = struct{}{}
		for _, e := range g.Exprs {
			for _, in := range e.Inputs {
				visit(in)
			}
		}
		ordered = append(ordered, g)
	}
	visit(space.root)
	return ordered
}

// String returns a multi-line, human-readable listing of every group reachable
// from the root, in bottom-up order. Each group is introduced by a
// "Group <ID>" line (with its logical properties in brackets when non-nil),
// followed by one tab-indented line per expression.
func (space *Space) String() string {
	var out strings.Builder
	for _, g := range space.groups() {
		out.WriteString("Group " + strconv.Itoa(g.ID))
		if g.LogicalProp != nil {
			fmt.Fprintf(&out, " [%+v]", g.LogicalProp)
		}
		out.WriteByte('\n')
		for _, e := range g.Exprs {
			fmt.Fprintf(&out, "\t%v\n", e)
		}
	}
	return out.String()
}

// DebugCostedBest writes the selected plan, including its costing
// information, to w. The plan is found by following each Group's Best Expr
// from the root. Each line shows one Best Expr's operator, indented four
// spaces per level below the root. Its local and combined costs and its
// group's logical properties follow, padded so that the costing columns
// line up across all lines. A group with no Best Expr is reported as such,
// and nothing below it is printed.
func (space *Space) DebugCostedBest(w bytes.StringWriter) {
	const indentWidth = 4

	// opWidth returns the length of the longest "indent + operator" string in
	// the best plan rooted at g, where g sits at the given depth.
	var opWidth func(depth int, g *Group) int
	opWidth = func(depth int, g *Group) int {
		if g.Best == nil {
			return 0
		}
		width := depth*indentWidth + len(g.Best.Operator.String())
		for _, input := range g.Best.Inputs {
			width = cmp.MaxInt(width, opWidth(depth+1, input))
		}
		return width
	}
	// The +1 guarantees at least one space between the widest operator and
	// its costs.
	maxLen := opWidth(0, space.root) + 1

	// printGroup (named to avoid shadowing the builtin print) writes g's Best
	// Expr and then, recursively, the Best Exprs of its inputs.
	var printGroup func(depth int, g *Group)
	printGroup = func(depth int, g *Group) {
		if g.Best == nil {
			fmt.Fprintf(w, "Group %d has no best Expr set\n", g.ID)
			return
		}
		// Like String and Debug, tolerate a nil LogicalProp instead of
		// panicking on the DetailString call.
		props := "<nil>"
		if g.LogicalProp != nil {
			props = g.LogicalProp.DetailString()
		}
		indent := depth * indentWidth
		fmt.Fprintf(w, "%s%s%s costs local %v combined %v logicalProps: %v\n",
			strings.Repeat(" ", indent),
			g.Best.Operator,
			strings.Repeat(" ", maxLen-indent-len(g.Best.Operator.String())),
			g.Best.LocalCost, g.Best.CombinedCost,
			props)
		for _, input := range g.Best.Inputs {
			printGroup(depth+1, input)
		}
	}
	printGroup(0, space.root)
}

// Debug writes detailed information about the current state of the space to
// w. It includes more detail than String(). Groups are listed from the root
// down to the leaves. Each group is followed by its logical properties, when
// non-nil. Each expression follows with its local and combined costs, once it
// has a LocalCost. A group's Best expression is tagged "[best]", or
// "[best,selected]" when the group is part of the plan reached by following
// Best pointers from the root.
func (space *Space) Debug(w bytes.StringWriter) {
	// Collect the IDs of all groups in the selected plan. A sub-plan shared by
	// several parents is walked only once; revisiting it could not add new
	// groups, and without this check the walk can be exponential on DAGs.
	planGroups := make(map[int]struct{})
	var add func(g *Group)
	add = func(g *Group) {
		if _, seen := planGroups[g.ID]; seen {
			return
		}
		planGroups[g.ID] = struct{}{}
		if g.Best != nil {
			for _, input := range g.Best.Inputs {
				add(input)
			}
		}
	}
	add(space.root)

	groups := space.groups()
	// groups() returns the leaves first and the root last, but when reviewing
	// spaces you usually start at the root and work down, so iterate backwards.
	for i := len(groups) - 1; i >= 0; i-- {
		group := groups[i]
		fmt.Fprintf(w, "Group %v", group.ID)
		if group.LogicalProp != nil {
			fmt.Fprintf(w, " [%s]", group.LogicalProp.DetailString())
		}
		fmt.Fprintln(w)

		// Render each expression once; the strings are used for both the
		// column width and the output.
		exprStrs := make([]string, len(group.Exprs))
		maxExprLen := 30 // allow at least 30 chars for the operator
		for j, expr := range group.Exprs {
			exprStrs[j] = expr.String()
			maxExprLen = cmp.MaxInt(maxExprLen, len(exprStrs[j]))
		}
		// Round up to a multiple of 10 to increase the odds of alignment
		// across groups.
		if maxExprLen%10 != 0 {
			maxExprLen += 10 - maxExprLen%10
		}
		_, inPlan := planGroups[group.ID]
		for j, expr := range group.Exprs {
			s := exprStrs[j]
			fmt.Fprintf(w, "\t%s", s)
			if expr.LocalCost != nil {
				fmt.Fprintf(w, "%s costs local %v combined %v",
					strings.Repeat(" ", maxExprLen+2-len(s)),
					expr.LocalCost, expr.CombinedCost)
				if expr == group.Best {
					fmt.Fprint(w, " [best")
					if inPlan {
						fmt.Fprint(w, ",selected")
					}
					fmt.Fprint(w, "]")
				}
			}
			fmt.Fprintln(w)
		}
	}
}

// Graphviz writes a Graphviz dot-formatted description of the search space
// to w. Each reachable Group becomes a cluster holding a triangle node for
// the group and one rectangle node per Expr. The group's Best Expr is filled
// green, and Exprs whose CombinedCost is nil or infinite are filled gray.
// Edges then run from each Expr to its input Groups.
//
// Label text (operators, costs, logical properties) is escaped, so a double
// quote or backslash inside it cannot break the dot syntax.
//
// Note: it ignores errors in writing to w.
func (space *Space) Graphviz(w io.Writer) {
	// Inside a double-quoted dot string, an unescaped '"' ends the string and
	// a backslash starts a label escape sequence (such as \n or \l), so both
	// are escaped. The replacer works in a single pass, so it never
	// double-escapes.
	labelEscaper := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	label := func(v interface{}) string {
		return labelEscaper.Replace(fmt.Sprint(v))
	}

	groups := space.groups()
	fmt.Fprintf(w, "digraph {\n")
	fmt.Fprintf(w, "  ranksep=1.5;\n")
	for _, group := range groups {
		// String and Debug treat a nil LogicalProp as "no properties"; do the
		// same here rather than panicking on the DetailString call.
		props := ""
		if group.LogicalProp != nil {
			props = group.LogicalProp.DetailString()
		}
		fmt.Fprintf(w, "  subgraph cluster_%v {\n", group.ID)
		fmt.Fprintf(w, "    bgcolor=aliceblue;\n")
		fmt.Fprintf(w, "    group_%v [label=\"Group %v\n%v\", shape=triangle];\n",
			group.ID, group.ID, label(props))
		for j, expr := range group.Exprs {
			fmt.Fprintf(w, "    group_%v -> expr_%v_%v [label=\"%v\"];\n", group.ID, group.ID, j, j)
			color := "aliceblue"
			if expr == group.Best {
				color = "green"
			}
			if expr.CombinedCost == nil || expr.CombinedCost.Infinite() {
				color = "gray"
			}
			fmt.Fprintf(w, "    expr_%v_%v [label=\"%v\ncost: %v\",shape=rectangle,style=filled,fillcolor=%v];\n",
				group.ID, j, label(expr), label(expr.CombinedCost), color)
		}
		fmt.Fprintf(w, "  } // subgraph cluster_%v \n", group.ID)
	}
	// Draw edges in a second pass so that they don't implicitly declare the other
	// group's node in the wrong subgraph.
	for _, group := range groups {
		for j, expr := range group.Exprs {
			for _, input := range expr.Inputs {
				fmt.Fprintf(w, "    expr_%v_%v -> group_%v;\n", group.ID, j, input.ID)
			}
		}
	}
	fmt.Fprintf(w, "} // digraph\n")
}

// patternSpace lists the characters treated as indentation (and trimmed as
// trailing whitespace) in Contains patterns.
const patternSpace = " \t"

// leftSpace returns the leading run of patternSpace characters in str. It
// returns all of str when str consists only of such characters, and "" when
// str starts with any other character.
func leftSpace(str string) string {
	body := strings.TrimLeft(str, patternSpace)
	return str[:len(str)-len(body)]
}

// patternTree is one node of a parsed Contains pattern: an operator string and
// the child patterns that must match that operator's inputs, in order.
type patternTree struct {
	// Expected String() value of the operator.
	op string
	// Child patterns, one per input group, in input order.
	inputs []*patternTree
}

// parsePattern reads a multi-line indented pattern string into a tree.
//
// Blank lines and trailing spaces/tabs are ignored. A line indented further
// than the node above it (its indentation extends that node's indentation)
// becomes an input of that node, and lines with identical indentation are
// siblings. It returns nil if the pattern has no non-blank lines. It panics if
// a later line is not indented strictly deeper than the first line, since
// such a line cannot be placed in the tree.
func parsePattern(pattern string) *patternTree {
	// Keep only lines with visible content, minus their trailing whitespace.
	var remaining []string
	for _, raw := range strings.Split(pattern, "\n") {
		trimmed := strings.TrimRight(raw, patternSpace)
		if strings.TrimLeft(trimmed, patternSpace) == "" {
			continue
		}
		remaining = append(remaining, trimmed)
	}
	if len(remaining) == 0 {
		return nil
	}

	// parseNode consumes remaining[0] as a node indented by exactly indent,
	// then consumes every following line indented deeper than indent as that
	// node's inputs. It stops at the first line that is a sibling (same
	// indent) or belongs to an ancestor (doesn't start with indent).
	var parseNode func(indent string) *patternTree
	parseNode = func(indent string) *patternTree {
		node := &patternTree{op: remaining[0][len(indent):]}
		remaining = remaining[1:]
		for len(remaining) > 0 && strings.HasPrefix(remaining[0], indent) {
			deeper := leftSpace(remaining[0][len(indent):])
			if deeper == "" {
				break
			}
			node.inputs = append(node.inputs, parseNode(indent+deeper))
		}
		return node
	}

	root := parseNode(leftSpace(remaining[0]))
	if len(remaining) != 0 {
		log.Panicf("Malformed pattern indentation")
	}
	return root
}

// matches reports whether expr's operator renders (via String) exactly as op
// and expr's input groups have exactly the IDs in inputs, in the same order.
// It's a helper to Contains.
func matches(expr *Expr, op string, inputs []int) bool {
	if expr.Operator.String() != op || len(expr.Inputs) != len(inputs) {
		return false
	}
	for i, in := range expr.Inputs {
		if in.ID != inputs[i] {
			return false
		}
	}
	return true
}

// Contains looks for the given expression tree in the space. It returns true if
// it found a matching expression tree, false otherwise. Contains is intended
// for unit tests only. It panics if the pattern isn't indented correctly or
// contains no non-blank lines.
//
// pattern should be a multi-line indented string like this:
//     Join
//         Select
//             Scan A
//         Join
//             Scan B
//             Scan C
// The lines must match the String() values of operators in the search space exactly.
// The tree may be found anywhere in the space, not only in the root group. If
// several expressions match a pattern node, an arbitrary one is used (the hash
// table is scanned in map iteration order).
func (space *Space) Contains(pattern string) bool {
	tree := parsePattern(pattern)
	if tree == nil {
		// parsePattern returns nil when the pattern has no non-blank lines;
		// without this check, find would dereference the nil tree.
		log.Panicf("Contains: pattern has no non-blank lines")
	}
	// find returns an expression matching node whose inputs are the groups of
	// expressions matching node's children, or nil if there is none.
	var find func(node *patternTree) *Expr
	find = func(node *patternTree) *Expr {
		var groupIDs []int
		for _, input := range node.inputs {
			expr := find(input)
			if expr == nil {
				return nil
			}
			groupIDs = append(groupIDs, expr.Group.ID)
		}
		for _, expr := range space.exprHash {
			if matches(expr, node.op, groupIDs) {
				return expr
			}
		}
		log.Printf("Couldn't find %v with inputs %v", node.op, groupIDs)
		return nil
	}
	return find(tree) != nil
}

// hasInput reports whether the group with the given ID is one of this expr's
// direct inputs.
func (expr *Expr) hasInput(groupID int) bool {
	found := false
	for i := 0; i < len(expr.Inputs) && !found; i++ {
		found = expr.Inputs[i].ID == groupID
	}
	return found
}

// Key implements cmp.Key. It writes the expression's identity to b: the
// operator's own key, then, when there are inputs, a space and the input
// group IDs as a bracketed, space-separated list (e.g. "<op key> [3 7]").
// This package uses the Key method to track expressions in a map; it's
// probably not useful outside this package but is public to implement
// util/cmp.Key.
func (expr *Expr) Key(b *strings.Builder) {
	b.Grow(128)
	expr.Operator.Key(b)
	if len(expr.Inputs) == 0 {
		return
	}
	b.WriteString(" [")
	for i, in := range expr.Inputs {
		if i > 0 {
			b.WriteByte(' ')
		}
		b.WriteString(strconv.Itoa(in.ID))
	}
	b.WriteByte(']')
}

// String returns a single-line description of the expression: the operator's
// default formatting, followed (when there are inputs) by a space and the
// input group IDs as a bracketed, space-separated list, e.g. "Join [3 7]".
func (expr *Expr) String() string {
	var b strings.Builder
	fmt.Fprint(&b, expr.Operator)
	if len(expr.Inputs) == 0 {
		return b.String()
	}
	b.WriteString(" [")
	for i, in := range expr.Inputs {
		if i > 0 {
			b.WriteByte(' ')
		}
		b.WriteString(strconv.Itoa(in.ID))
	}
	b.WriteByte(']')
	return b.String()
}

// MustCheckInvariants calls CheckInvariants and panics with the returned
// error if any invariant is violated.
func (space *Space) MustCheckInvariants() {
	if err := space.CheckInvariants(); err != nil {
		log.Panicf("Plan space is corrupt: %v", err)
	}
}

// CheckInvariants runs internal consistency checks on the space data structure.
// It returns an error describing the first violation found, or nil. It was
// initially used for testing and debugging but may also be called less
// frequently during runtime.
//
// It verifies that:
//   - every group reachable from the root has at least one expression;
//   - every reachable expression points back to the group holding it, is not
//     stale, and has an id distinct from every other reachable expression;
//   - every hash table entry is stored under its expression's key;
//   - every reachable expression is the very instance stored under its key in
//     the hash table, and the hash table holds no other entries.
func (space *Space) CheckInvariants() error {
	groups := space.groups()

	// Check that all groups have expressions.
	for _, group := range groups {
		if len(group.Exprs) == 0 {
			return fmt.Errorf("reachable group %v has no expressions:\n%v", group.ID, space)
		}
	}

	// Check that all groups' expressions have the correct group, are not
	// flagged as stale, and have unique ids.
	exprIDs := make(map[int]*Expr, len(space.exprHash))
	for _, group := range groups {
		for _, expr := range group.Exprs {
			if expr.Group != group {
				// Report the group's ID; formatting the *Group itself would
				// dump its fields and pointer addresses.
				exprGroup := "nil"
				if expr.Group != nil {
					exprGroup = strconv.Itoa(expr.Group.ID)
				}
				return fmt.Errorf("expr %v has group %v but found in group %v:\n%v",
					expr, exprGroup, group.ID, space)
			}
			if expr.isStale() {
				return fmt.Errorf("expr should not be stale: expr %d %v in group %d is stale because: %v",
					expr.id, expr, group.ID, expr.staleReason)
			}
			if prev, exists := exprIDs[expr.id]; exists {
				return fmt.Errorf("expr %v has id %d, but expr %v has the same id", expr, expr.id, prev)
			}
			exprIDs[expr.id] = expr
		}
	}

	// Check that everything in the hash table matches its expression.
	for key, expr := range space.exprHash {
		if actual := cmp.GetKey(expr); key != actual {
			return fmt.Errorf("expr %v found with key %v instead of %v:\n%v",
				expr, key, actual, space)
		}
	}

	// Check that all reachable groups' expressions are in the hash table and
	// vice versa. The hash table must hold this exact instance: a different
	// (e.g. stale or duplicate) Expr under the same key means the memo is
	// corrupt even though the key is present.
	numExprs := 0
	for _, group := range groups {
		for _, expr := range group.Exprs {
			numExprs++
			stored, found := space.exprHash[cmp.GetKey(expr)]
			if !found {
				return fmt.Errorf("expr %v not found in hash table:\n%v", expr, space)
			}
			if stored != expr {
				return fmt.Errorf("expr %d %v in group %d is not the instance in the hash table (found expr %d %v):\n%v",
					expr.id, expr, group.ID, stored.id, stored, space)
			}
		}
	}
	if len(space.exprHash) != numExprs {
		return fmt.Errorf("hash table has %v exprs, but groups have %v:\n%v",
			len(space.exprHash), numExprs, space)
	}

	return nil
}
